import contextlib
import os
import threading
from dataclasses import dataclass
from enum import Enum, IntEnum
from typing import Dict, List

import torch
from torch.nn import functional as F

from tensorrt_llm._utils import TensorWrapper, convert_to_torch_tensor
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.math_utils import ceil_div, pad_up
from tensorrt_llm.quantization.utils import fp4_utils

# Module-level flag written by set_torch_compiling() and read by
# is_torch_compiling(). Defaults to False.
is_torch_compiling_flag = False
# Module-level flag written by set_piecewise_running() and read by
# is_piecewise_running(). Defaults to False.
is_piecewise_running_flag = False

# Single source of truth for the auxiliary stream names; both enums below are
# generated from it so their shared members stay in sync.
aux_stream_name_list = [
    'Attention', 'MoeShared', 'MoeChunkingOverlap', 'MoeBalancer'
]
# The functional Enum API numbers members from 1, so Attention == 1, etc.
AuxStreamType = Enum('AuxStreamType', aux_stream_name_list)
# 'Main' takes value 0 (start=0); every auxiliary name then gets the same value
# it has in AuxStreamType.
EventType = Enum('EventType', ['Main'] + aux_stream_name_list, start=0)


# IMPORTANT: Keep the same order of activation functions in this enum and the enum in
# cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
class ActivationType(IntEnum):
    """Integer identifiers for activation functions.

    The numeric values mirror an enum on the C++ side (see the comment
    directly above this class), so members must not be renumbered or
    reordered independently of that header.
    """
    InvalidType = 0
    Identity = 1
    Gelu = 2
    Relu = 3
    Silu = 4
    Swiglu = 5
    Geglu = 6
    SwigluBias = 7
    Relu2 = 8


# IMPORTANT: when adding a new activation type, please update this function
# and make sure it stays aligned with the isGatedActivation function in
# cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h.
def is_gated_activation(activation_type: ActivationType) -> bool:
    """Return True when ``activation_type`` is Swiglu, SwigluBias or Geglu.

    Args:
        activation_type: The activation identifier to classify.
    Returns:
        True for the three gated variants listed above, False otherwise.
    """
    gated_types = (
        ActivationType.Swiglu,
        ActivationType.SwigluBias,
        ActivationType.Geglu,
    )
    return activation_type in gated_types


def set_torch_compiling(enable: bool):
    """Set the module-level ``is_torch_compiling_flag`` to ``enable``.

    The flag is a plain module global, so the new value is visible to every
    caller of is_torch_compiling() in this process.
    """
    global is_torch_compiling_flag
    is_torch_compiling_flag = enable


def is_torch_compiling() -> bool:
    """Return the current value of the module-level ``is_torch_compiling_flag``."""
    # Reading a module global needs no ``global`` declaration.
    return is_torch_compiling_flag


def set_piecewise_running(enable: bool):
    """Set the module-level ``is_piecewise_running_flag`` to ``enable``.

    maybe_compile() in this file consults the flag (via
    is_piecewise_running()) to decide whether to call the compiled or the
    original function.
    """
    global is_piecewise_running_flag
    is_piecewise_running_flag = enable


def is_piecewise_running() -> bool:
    """Return the current value of the module-level ``is_piecewise_running_flag``."""
    # Reading a module global needs no ``global`` declaration.
    return is_piecewise_running_flag


# Thread-local namespace handed out by get_global_attrs(); attributes set on
# it are visible only to the thread that set them.
_global_attrs = threading.local()


def get_global_attrs():
    """Return the module's thread-local ``_global_attrs`` namespace object.

    Callers read and write arbitrary attributes on the returned object; each
    thread sees only its own values.
    """
    return _global_attrs


# Thread-local holder whose ``attrs`` attribute is installed by the
# model_extra_attrs() context manager and read by get_model_extra_attrs().
_model_extra_attrs = threading.local()


def get_model_extra_attrs():
    """Return the calling thread's current extra attrs.

    Returns:
        Whatever is stored in ``_model_extra_attrs.attrs`` for this thread, or
        ``None`` when the attribute has never been set on this thread.
    """
    return getattr(_model_extra_attrs, 'attrs', None)


@contextlib.contextmanager
def model_extra_attrs(attrs: Dict):
    """Temporarily install ``attrs`` as the calling thread's extra attrs.

    Args:
        attrs: Object to expose through get_model_extra_attrs() while the
            ``with`` body runs.

    On exit (normal or via exception) the previously installed value is put
    back; if nothing was installed before, ``None`` is stored.
    """
    previous_attrs = getattr(_model_extra_attrs, 'attrs', None)
    _model_extra_attrs.attrs = attrs
    try:
        yield
    finally:
        # Always restore, so nested uses unwind correctly.
        _model_extra_attrs.attrs = previous_attrs


def with_model_extra_attrs(get_attrs):
    """Build a method decorator that runs the method inside model_extra_attrs().

    Args:
        get_attrs: Callable taking the instance (``self``) and returning the
            attrs object to install for the duration of the call.
    Returns:
        A decorator for methods. The decorated method keeps the original's
        ``__name__``, ``__doc__`` and related metadata.
    """
    # Local import keeps this block self-contained.
    import functools

    def decorator(func):

        # Preserve the wrapped method's metadata for introspection/debugging.
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            with model_extra_attrs(get_attrs(self)):
                return func(self, *args, **kwargs)

        return wrapper

    return decorator


def make_weak_ref(x):
    """Recursively replace CUDA tensors in ``x`` with TensorWrapper-backed tensors.

    CUDA tensors are rebuilt from their data pointer, dtype, shape and stride
    via ``TensorWrapper`` + ``convert_to_torch_tensor``.
    NOTE(review): presumably the rebuilt tensor aliases the same memory without
    owning it — confirm against TensorWrapper.

    Args:
        x: A tensor, a tuple/list/dict of supported values (dict values are
            converted, keys are kept), or an int/float/bool.
    Returns:
        A structure of the same layout. Non-CUDA tensors and scalars are
        returned unchanged; tuples (including subclasses) come back as plain
        tuples and lists as plain lists.
    Raises:
        TypeError: If ``x`` (or a nested element) is of any other type.
    """
    if isinstance(x, torch.Tensor):
        if not x.is_cuda:
            return x
        wrapper = TensorWrapper(x.data_ptr(), x.dtype, x.shape, x.stride())
        return convert_to_torch_tensor(wrapper)
    if isinstance(x, tuple):
        return tuple(map(make_weak_ref, x))
    if isinstance(x, list):
        return list(map(make_weak_ref, x))
    if isinstance(x, dict):
        return {key: make_weak_ref(value) for key, value in x.items()}
    if isinstance(x, (int, float, bool)):
        return x
    raise TypeError(f"Invalid type {type(x)} to make weak ref")


@dataclass
class Fp4QuantizedTensor:
    """Pairs a quantized tensor with its scaling factors.

    NOTE(review): the field names suggest FP4 data plus block scaling
    factors; the exact dtypes/layouts are not established in this file.
    """
    # The quantized payload; ``shape`` below is forwarded from this tensor.
    fp4_tensor: torch.Tensor
    # Scaling factors that accompany ``fp4_tensor``.
    scaling_factor: torch.Tensor
    # Whether ``scaling_factor`` is stored in swizzled layout (default True).
    is_sf_swizzled: bool = True

    @property
    def shape(self):
        """Shape of ``fp4_tensor``; the scaling factor does not contribute."""
        payload = self.fp4_tensor
        return payload.shape


def compute_swizzled_sf_shape(row: int, col: int):
    """Return the padded (rows, cols) of a swizzled scaling-factor matrix.

    Args:
        row: Number of scaling-factor rows.
        col: Number of scaling-factor columns.
    Returns:
        ``(pad_up(row, 128), pad_up(col, 4))``.
    """
    return pad_up(row, 128), pad_up(col, 4)


def swizzle_sf(sf: torch.Tensor,
               rows: int,
               cols: int,
               scaling_vector_size: int = 16):
    """Swizzle FP4 scaling factors using the C++ torch op implementation.
    Args:
        sf: [b, rows, cols_sf] or [rows, cols_sf], where
            cols_sf = ceil_div(cols, scaling_vector_size). The original unswizzled scaling factors.
        rows: rows of the original unquantized tensor
        cols: cols of the original unquantized tensor
        scaling_vector_size: the size of the scaling vector
    Returns:
        [b * pad_up(rows, 128) * pad_up(cols_sf, 4), ] 1D swizzled scaling factors, possibly with rows and cols padded.
    """
    # Normalize to a batched 3D view before handing off to the op.
    batched_sf = sf.view(-1, rows, ceil_div(cols, scaling_vector_size))
    return torch.ops.trtllm.block_scale_interleave(batched_sf)


def unswizzle_sf(sf: torch.Tensor,
                 rows: int,
                 cols: int,
                 scaling_vector_size: int = 16):
    """Unswizzle FP4 scaling factors using the C++ torch op implementation.
    Args:
        sf: The (padded and) swizzled scaling factors.
        rows: row count used to view ``sf`` as [-1, rows, sf_cols]; reswizzle_sf
            in this file passes the padded row count here.
        cols: column count of the unquantized tensor that ``sf`` covers;
            sf_cols = ceil_div(cols, scaling_vector_size).
        scaling_vector_size: the size of the scaling vector
    Returns:
        2D unswizzled scaling factors viewed as [-1, sf_cols]
    """
    sf_cols = ceil_div(cols, scaling_vector_size)
    batched_sf = sf.view(-1, rows, sf_cols)
    unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(batched_sf)
    return unswizzled.view(-1, sf_cols)


@torch.library.custom_op("trtllm::reswizzle_sf", mutates_args=())
def reswizzle_sf(sf: torch.Tensor,
                 rows: int,
                 cols: int,
                 scaling_vector_size: int = 16) -> torch.Tensor:
    """Reswizzle FP4 scaling factors that were swizzled partition by partition.

    Each partition is unswizzled, its padding is stripped, the partitions are
    stacked along the row axis, and the stacked matrix is swizzled once more.
    Args:
        sf: The (padded and) swizzled scaling factors; its element count must
            be a multiple of one padded partition.
        rows: rows of the original unquantized tensor (per partition)
        cols: cols of the original unquantized tensor
        scaling_vector_size: the size of the scaling vector
    Returns:
        1D reswizzled scaling factors
    """
    sf_cols = ceil_div(cols, scaling_vector_size)
    padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)

    # One partition occupies padded_rows * padded_sf_cols elements.
    partition_numel = padded_rows * padded_sf_cols
    assert sf.numel() % partition_numel == 0
    num_partitions = sf.numel() // partition_numel

    # Undo the swizzle of every partition (padding still included).
    per_partition = sf.view(num_partitions, padded_rows, padded_sf_cols)
    unswizzled = unswizzle_sf(per_partition, padded_rows,
                              padded_sf_cols * scaling_vector_size,
                              scaling_vector_size)
    unswizzled = unswizzled.view(num_partitions, padded_rows, padded_sf_cols)

    # Strip the padding and stack all partitions along the row axis.
    # TODO: This will incur a elementwise kernel
    valid = unswizzled[:, :rows, :sf_cols].contiguous()
    total_rows = num_partitions * rows
    stacked = valid.view(total_rows, sf_cols)

    # Swizzle the stacked scaling factors as one matrix.
    return swizzle_sf(stacked, total_rows, cols, scaling_vector_size)


@torch.library.register_fake("trtllm::reswizzle_sf")
def _(sf, rows, cols, scaling_vector_size=16):
    """Fake (shape-only) implementation of ``trtllm::reswizzle_sf``.

    Args:
        sf: The (padded and) swizzled scaling factors.
        rows: rows of the original unquantized tensor (per partition)
        cols: cols of the original unquantized tensor
        scaling_vector_size: the size of the scaling vector
    Returns:
        An uninitialized 1D tensor with the element count the real op produces.
    """
    sf_cols = ceil_div(cols, scaling_vector_size)
    padded_rows, padded_sf_cols = compute_swizzled_sf_shape(rows, sf_cols)
    num_partitions = sf.numel() // (padded_rows * padded_sf_cols)
    total_rows = num_partitions * rows
    # The real op finishes with swizzle_sf(..., total_rows, cols, ...), whose
    # output has pad_up(total_rows, 128) * pad_up(sf_cols, 4) elements. The
    # column padding applies to the scaling-factor columns, not to ``cols``.
    out_rows, out_sf_cols = compute_swizzled_sf_shape(total_rows, sf_cols)
    return sf.new_empty(out_rows * out_sf_cols)


def next_positive_power_of_2(x: int) -> int:
    """Return the smallest power of two that is >= ``x``; 1 for any ``x`` < 1.

    Args:
        x: Value to round up.
    Returns:
        A power of two (correct for inputs that fit in 64 bits).
    """
    if x < 1:
        return 1

    # Equivalent to 1 << (x - 1).bit_length(), but written without
    # bit_length() so it can be used by torch compile.
    # Smear the highest set bit of (x - 1) into every lower position; shifts
    # up to 32 cover 64-bit values, which should be enough for now.
    smeared = x - 1
    for shift in (1, 2, 4, 8, 16, 32):
        smeared |= smeared >> shift
    return smeared + 1


def last_positive_power_of_2(x: int) -> int:
    """Return the largest power of two that is <= ``x``.

    Args:
        x: Value to round down.
    Returns:
        ``x`` itself when it is already a power of two, otherwise half of
        next_positive_power_of_2(x). Because that helper returns 1 for any
        ``x`` < 1, this function returns 0 for such inputs.
    """
    upper = next_positive_power_of_2(x)
    return upper if upper == x else upper // 2


def nearest_in_buckets(x: int, buckets: List[int]) -> int:
    """Round ``x`` up to a power of two and clamp it to the bucket range.

    Args:
        x: Value to bucket.
        buckets: Non-empty sequence; only its first and last elements are
            used, as the lower and upper clamp bounds.
    Returns:
        next_positive_power_of_2(x) limited to [buckets[0], buckets[-1]].
    """
    rounded = next_positive_power_of_2(x)
    at_least_lowest = max(rounded, buckets[0])
    return min(at_least_lowest, buckets[-1])


def get_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
    """List token-count buckets from 1 up to ``max_num_tokens`` rounded up.

    Args:
        max_num_tokens: Upper limit; rounded up with next_positive_power_of_2.
    Returns:
        Ascending values obtained by repeatedly halving the rounded limit down
        to 1. Note the result is a tuple, despite the ``List[int]`` annotation.
    """
    size = next_positive_power_of_2(max_num_tokens)
    descending = []
    while size >= 1:
        descending.append(size)
        size //= 2

    return tuple(reversed(descending))


def get_last_power_of_2_num_tokens_buckets(max_num_tokens) -> List[int]:
    """List token-count buckets from 1 up to ``max_num_tokens`` rounded down.

    Args:
        max_num_tokens: Upper limit; rounded down with last_positive_power_of_2.
    Returns:
        Ascending values obtained by repeatedly halving the rounded limit down
        to 1 (empty when the rounded limit is below 1). Note the result is a
        tuple, despite the ``List[int]`` annotation.
    """
    size = last_positive_power_of_2(max_num_tokens)
    descending = []
    while size >= 1:
        descending.append(size)
        size //= 2
    return tuple(reversed(descending))


def fp4_scale_infer_shape(input_shapes: List[List[int]]):
    """Calculate the dimensions of the fp4 scale tensor.

    Args:
        input_shapes: Shapes of the op inputs; only ``input_shapes[0]`` is used.
    Returns:
        The scale shape reported by ``fp4_utils.get_fp4_shape`` (with
        ``sf_vec_size=16``) multiplied by 2.
        NOTE(review): whether that multiplication doubles a number or repeats
        a sequence depends on get_fp4_shape's return type — confirm.
    """
    # The first returned element (the output shape) is not needed here.
    _, scale_shape = fp4_utils.get_fp4_shape(input_shapes[0], sf_vec_size=16)
    return scale_shape * 2


# Process-wide (not thread-local) flag behind set_piecewise_cuda_graph_flag(),
# get_piecewise_cuda_graph_flag() and the piecewise_cuda_graph() context
# manager. Defaults to True.
_enable_piecewise_cuda_graph = True


def set_piecewise_cuda_graph_flag(enable: bool):
    """Set the module-level ``_enable_piecewise_cuda_graph`` flag to ``enable``."""
    global _enable_piecewise_cuda_graph
    _enable_piecewise_cuda_graph = enable


def get_piecewise_cuda_graph_flag() -> bool:
    """Return the current module-level ``_enable_piecewise_cuda_graph`` flag."""
    # Reading a module global needs no ``global`` declaration.
    return _enable_piecewise_cuda_graph


@contextlib.contextmanager
def piecewise_cuda_graph(enable: bool):
    """Temporarily set the piecewise CUDA graph flag to ``enable``.

    Args:
        enable: Value the flag holds while the ``with`` body runs.

    The value that was in effect on entry is restored on exit, whether the
    body finishes normally or raises.
    """
    saved_flag = get_piecewise_cuda_graph_flag()
    set_piecewise_cuda_graph_flag(enable)
    try:
        yield
    finally:
        set_piecewise_cuda_graph_flag(saved_flag)


def set_per_request_piecewise_cuda_graph_flag(enable: bool):
    """Store ``enable`` on the thread-local ``_global_attrs`` namespace.

    Unlike set_piecewise_cuda_graph_flag(), the value is only visible to the
    calling thread.
    """
    _global_attrs.per_request_piecewise_cuda_graph_flag = enable


def get_per_request_piecewise_cuda_graph_flag() -> bool:
    """Return the calling thread's per-request piecewise CUDA graph flag.

    Returns:
        The value last stored by set_per_request_piecewise_cuda_graph_flag()
        on this thread, or ``True`` if it has never been set on this thread.
    """
    return getattr(_global_attrs, 'per_request_piecewise_cuda_graph_flag', True)


def create_lm_head_tp_mapping(mapping: Mapping, token_count: int) -> Mapping:
    """Derive a Mapping for the LM head with its own TP/PP split.

    Args:
        mapping: The source mapping; its rank, gpus_per_node and attention-DP
            settings are carried over unchanged.
        token_count: Number of tokens used by the TP-size heuristic. Must be
            non-zero (it is used as a divisor).
    Returns:
        A new Mapping whose tp_size is the chosen LM-head TP size and whose
        pp_size is ``mapping.pp_size * mapping.tp_size // lm_head_tp_size``.

    The ``LM_HEAD_TP_SIZE`` environment variable, when set, overrides the
    heuristic. The chosen size must divide ``mapping.tp_size`` (asserted).
    """
    # We use heuristic to determine the lm_head_tp_size
    # Since token_count=256 will hit the boundary of math-bound problem
    # We use 256 // token_count to determine the lm_head_tp_size
    # For more details, refer to the blog: https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog14_Scaling_Expert_Parallelism_in_TensorRT-LLM_part3.md#mtp-lm-head-tensor-parallelism
    heuristic_tp_size = 256 // token_count
    # TODO: On platforms like GB200, setting lm_head_tp_size_upper_bound to world_size could be more efficient when world_size > gpus_per_node, we need to do further investigation.
    tp_size_upper_bound = min(mapping.world_size, mapping.gpus_per_node)
    default_tp_size = nearest_in_buckets(heuristic_tp_size,
                                         [1, tp_size_upper_bound])
    # The heuristic value is only the fallback; the environment wins when set.
    lm_head_tp_size = int(os.getenv('LM_HEAD_TP_SIZE', default_tp_size))
    assert mapping.tp_size % lm_head_tp_size == 0, f"mapping.tp_size: {mapping.tp_size}, lm_head_tp_size: {lm_head_tp_size}"

    # Ranks not used for LM-head TP are folded into the PP dimension.
    total_model_parallel = mapping.pp_size * mapping.tp_size
    lm_head_pp_size = total_model_parallel // lm_head_tp_size

    mapping_kwargs = dict(
        world_size=lm_head_tp_size * lm_head_pp_size,
        rank=mapping.rank,
        gpus_per_node=mapping.gpus_per_node,
        tp_size=lm_head_tp_size,
        pp_size=lm_head_pp_size,
        enable_attention_dp=mapping.enable_attention_dp,
        enable_lm_head_tp_in_adp=mapping.enable_lm_head_tp_in_adp,
    )
    return Mapping(**mapping_kwargs)


def get_device_uuid(device_idx: int) -> str:
    """Get the UUID of a CUDA device using torch cuda api

    Args:
        device_idx: Index of the CUDA device to query.
    Returns:
        The string form of the device's ``uuid`` property prefixed with "GPU-".
    """
    device_props = torch.cuda.get_device_properties(device_idx)
    return f"GPU-{device_props.uuid!s}"


def maybe_compile(func=None, **compile_kwargs):
    """
    Conditionally compile a function with torch.compile.
    If is_piecewise_running() is True, the function will not be compiled to avoid host overhead in attention op.
    Usable both as ``@maybe_compile`` and as ``@maybe_compile(**compile_kwargs)``.
    Args:
        func: The function to decorate (optional, for direct decoration).
        **compile_kwargs: Keyword arguments for torch.compile.
    Returns:
        The conditionally compiled function when ``func`` is given, otherwise
        a decorator. The returned wrapper keeps the original function's
        ``__name__``, ``__doc__`` and related metadata.
    """
    # Local import keeps this block self-contained.
    import functools

    def decorator(f):
        # torch.compile is applied once, at decoration time; the choice
        # between compiled and eager is made on every call.
        compiled_func = torch.compile(f, **compile_kwargs)

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if is_piecewise_running():
                return f(*args, **kwargs)
            return compiled_func(*args, **kwargs)

        return wrapper

    # Compare against None explicitly so a callable that happens to be falsy
    # is still decorated instead of being treated as "no function given".
    return decorator(func) if func is not None else decorator


def split(x: torch.Tensor,
          tp_size: int,
          idx: int,
          dim: int = 0) -> torch.Tensor:
    """Return the ``idx``-th of ``tp_size`` equal slices of ``x`` along ``dim``.

    Args:
        x: Tensor to slice.
        tp_size: Number of equal parts; must evenly divide ``x.shape[dim]``
            (asserted).
        idx: Which part to return.
        dim: Dimension to slice along.
    Returns:
        ``x`` itself when ``tp_size`` is 1, otherwise the selected slice as
        produced by ``torch.split``.
    """
    dim_size = x.shape[dim]
    assert dim_size % tp_size == 0
    if tp_size == 1:
        # Nothing to slice; hand back the input tensor itself.
        return x
    part_size = dim_size // tp_size
    parts = torch.split(x, part_size, dim=dim)
    return parts[idx]


def relu2(x: torch.Tensor) -> torch.Tensor:
    """Apply ReLU to ``x`` and square the result elementwise."""
    rectified = F.relu(x)
    return torch.square(rectified)
